{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distribution Schema Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
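    "This tutorial walks you through the distribution schema system, using the Boston housing dataset as a running example.\n",
    "A final section provides guidance for contributing new distribution types to the code base.\n",
    "\n",
    "**NOTE**: PyTorch is required for this tutorial. The data is loaded with `sklearn.datasets.load_boston`, which was deprecated in scikit-learn 1.0 and removed in 1.2, so running the notebook as written needs scikit-learn < 1.2."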
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join(\"../../..\"))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's look at the Boston housing data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".. _boston_dataset:\n",
      "\n",
      "Boston house prices dataset\n",
      "---------------------------\n",
      "\n",
      "**Data Set Characteristics:**  \n",
      "\n",
      "    :Number of Instances: 506 \n",
      "\n",
      "    :Number of Attributes: 13 numeric/categorical predictive. Median Value (attribute 14) is usually the target.\n",
      "\n",
      "    :Attribute Information (in order):\n",
      "        - CRIM     per capita crime rate by town\n",
      "        - ZN       proportion of residential land zoned for lots over 25,000 sq.ft.\n",
      "        - INDUS    proportion of non-retail business acres per town\n",
      "        - CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)\n",
      "        - NOX      nitric oxides concentration (parts per 10 million)\n",
      "        - RM       average number of rooms per dwelling\n",
      "        - AGE      proportion of owner-occupied units built prior to 1940\n",
      "        - DIS      weighted distances to five Boston employment centres\n",
      "        - RAD      index of accessibility to radial highways\n",
      "        - TAX      full-value property-tax rate per $10,000\n",
      "        - PTRATIO  pupil-teacher ratio by town\n",
      "        - B        1000(Bk - 0.63)^2 where Bk is the proportion of black people by town\n",
      "        - LSTAT    % lower status of the population\n",
      "        - MEDV     Median value of owner-occupied homes in $1000's\n",
      "\n",
      "    :Missing Attribute Values: None\n",
      "\n",
      "    :Creator: Harrison, D. and Rubinfeld, D.L.\n",
      "\n",
      "This is a copy of UCI ML housing dataset.\n",
      "https://archive.ics.uci.edu/ml/machine-learning-databases/housing/\n",
      "\n",
      "\n",
      "This dataset was taken from the StatLib library which is maintained at Carnegie Mellon University.\n",
      "\n",
      "The Boston house-price data of Harrison, D. and Rubinfeld, D.L. 'Hedonic\n",
      "prices and the demand for clean air', J. Environ. Economics & Management,\n",
      "vol.5, 81-102, 1978.   Used in Belsley, Kuh & Welsch, 'Regression diagnostics\n",
      "...', Wiley, 1980.   N.B. Various transformations are used in the table on\n",
      "pages 244-261 of the latter.\n",
      "\n",
      "The Boston house-price data has been used in many machine learning papers that address regression\n",
      "problems.   \n",
      "     \n",
      ".. topic:: References\n",
      "\n",
      "   - Belsley, Kuh & Welsch, 'Regression diagnostics: Identifying Influential Data and Sources of Collinearity', Wiley, 1980. 244-261.\n",
      "   - Quinlan,R. (1993). Combining Instance-Based and Model-Based Learning. In Proceedings on the Tenth International Conference of Machine Learning, 236-243, University of Massachusetts, Amherst. Morgan Kaufmann.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_boston\n",
    "data = load_boston()\n",
    "print(data[\"DESCR\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CRIM</th>\n",
       "      <th>ZN</th>\n",
       "      <th>INDUS</th>\n",
       "      <th>CHAS</th>\n",
       "      <th>NOX</th>\n",
       "      <th>RM</th>\n",
       "      <th>AGE</th>\n",
       "      <th>DIS</th>\n",
       "      <th>RAD</th>\n",
       "      <th>TAX</th>\n",
       "      <th>PTRATIO</th>\n",
       "      <th>B</th>\n",
       "      <th>LSTAT</th>\n",
       "      <th>MEDV</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00632</td>\n",
       "      <td>18.0</td>\n",
       "      <td>2.31</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.538</td>\n",
       "      <td>6.575</td>\n",
       "      <td>65.2</td>\n",
       "      <td>4.0900</td>\n",
       "      <td>1.0</td>\n",
       "      <td>296.0</td>\n",
       "      <td>15.3</td>\n",
       "      <td>396.90</td>\n",
       "      <td>4.98</td>\n",
       "      <td>24.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.02731</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>6.421</td>\n",
       "      <td>78.9</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2.0</td>\n",
       "      <td>242.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>396.90</td>\n",
       "      <td>9.14</td>\n",
       "      <td>21.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02729</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>7.185</td>\n",
       "      <td>61.1</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2.0</td>\n",
       "      <td>242.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>392.83</td>\n",
       "      <td>4.03</td>\n",
       "      <td>34.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.03237</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.18</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.458</td>\n",
       "      <td>6.998</td>\n",
       "      <td>45.8</td>\n",
       "      <td>6.0622</td>\n",
       "      <td>3.0</td>\n",
       "      <td>222.0</td>\n",
       "      <td>18.7</td>\n",
       "      <td>394.63</td>\n",
       "      <td>2.94</td>\n",
       "      <td>33.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.06905</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.18</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.458</td>\n",
       "      <td>7.147</td>\n",
       "      <td>54.2</td>\n",
       "      <td>6.0622</td>\n",
       "      <td>3.0</td>\n",
       "      <td>222.0</td>\n",
       "      <td>18.7</td>\n",
       "      <td>396.90</td>\n",
       "      <td>5.33</td>\n",
       "      <td>36.2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD    TAX  \\\n",
       "0  0.00632  18.0   2.31   0.0  0.538  6.575  65.2  4.0900  1.0  296.0   \n",
       "1  0.02731   0.0   7.07   0.0  0.469  6.421  78.9  4.9671  2.0  242.0   \n",
       "2  0.02729   0.0   7.07   0.0  0.469  7.185  61.1  4.9671  2.0  242.0   \n",
       "3  0.03237   0.0   2.18   0.0  0.458  6.998  45.8  6.0622  3.0  222.0   \n",
       "4  0.06905   0.0   2.18   0.0  0.458  7.147  54.2  6.0622  3.0  222.0   \n",
       "\n",
       "   PTRATIO       B  LSTAT  MEDV  \n",
       "0     15.3  396.90   4.98  24.0  \n",
       "1     17.8  396.90   9.14  21.6  \n",
       "2     17.8  392.83   4.03  34.7  \n",
       "3     18.7  394.63   2.94  33.4  \n",
       "4     18.7  396.90   5.33  36.2  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame(data.data, columns=data[\"feature_names\"])\n",
    "df[\"MEDV\"] = data.target\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking at the data above, the columns naturally group into distribution types that may fit them better:\n",
    "\n",
    "Generic Continuous:\n",
    "\n",
    "- CRIM\n",
    "\n",
    "- NOX\n",
    "\n",
    "- RM\n",
    "\n",
    "- DIS\n",
    "\n",
    "- TAX\n",
    "\n",
    "- MEDV\n",
    "\n",
    "Proportion (0, 1 bounded):\n",
    "\n",
    "- ZN\n",
    "\n",
    "- INDUS\n",
    "\n",
    "- AGE\n",
    "\n",
    "- PTRATIO\n",
    "\n",
    "- LSTAT\n",
    "\n",
    "Binary:\n",
    "\n",
    "- CHAS\n",
    "\n",
    "Categorical:\n",
    "\n",
    "- RAD\n",
    "\n",
    "A schema specifies the distribution type of each column so that the fit is more accurate.\n",
    "The normal distribution is often a good approximation for generic continuous data.\n",
    "Note that the distributional assumptions must be *conditional* on the causal variables X:\n",
    "\n",
    "$$ p(y | X ; \\theta)  \\sim Dist $$\n",
    "\n",
    "\n",
    "The schema is a dictionary that maps each column name to the string alias of its distribution type:\n",
    "\n",
    "`{col_name: dist_type_alias}`\n",
    "\n",
    "The supported distribution types and their string aliases are listed in\n",
    "`causalnex/structure/pytorch/dist_type/__init__.py`.\n",
    "Supporting more distributions lets CausalNex handle a wider variety of datasets. If you are interested in helping out, see the `Developer Guide` section below!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"600px\"\n",
       "            src=\"supporting_files/02_boston_housing.html\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x7fd2632ff220>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# currently causalnex only supports continuous and binary data\n",
    "cont_cols = ['CRIM', 'NOX', 'RM', 'DIS', 'TAX', 'MEDV']\n",
    "bin_cols = ['CHAS']\n",
    "\n",
    "# subset data for which there is an available schema\n",
    "all_cols = cont_cols + bin_cols\n",
    "# copy so that standardising below does not write to a view of df\n",
    "df_subset = df[all_cols].copy()\n",
    "# the current Gaussian distribution type assumes unit variance, so standardise the continuous columns\n",
    "df_subset[cont_cols] = (df_subset[cont_cols] - df_subset[cont_cols].mean()) / df_subset[cont_cols].std()\n",
    "\n",
    "# insert into schema as colname:dist_type\n",
    "schema = {}\n",
    "for col in bin_cols:\n",
    "    schema[col] = \"bin\"\n",
    "for col in cont_cols:\n",
    "    schema[col] = \"cont\"\n",
    "\n",
    "# NOTE: only the pytorch version supports multiple distribution types at the moment\n",
    "from causalnex.structure.pytorch import from_pandas\n",
    "sm = from_pandas(df_subset, dist_type_schema=schema, lasso_beta=1e-5, w_threshold=0.0, use_bias=True)\n",
    "sm.threshold_till_dag()\n",
    "\n",
    "from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE\n",
    "viz = plot_structure(\n",
    "    sm, \n",
    "    all_node_attributes=NODE_STYLE.NORMAL,\n",
    "    all_edge_attributes=EDGE_STYLE.NORMAL\n",
    ")\n",
    "\n",
    "\n",
    "viz.show(\"supporting_files/02_boston_housing.html\")\n"
   ]
  },
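  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standardisation step in the cell above follows from the unit-variance assumption noted in its comment. As a sketch, assuming the `cont` dist type uses a Gaussian likelihood whose variance is fixed at $\\sigma^2 = 1$, the per-row negative log-likelihood of a column $y$ with reconstruction $\\hat{y}$ is\n",
    "\n",
    "$$ -\\log p(y \\mid X ; \\theta) = \\frac{1}{2}\\left(y - \\hat{y}\\right)^2 + \\frac{1}{2}\\log 2\\pi $$\n",
    "\n",
    "so minimising it amounts to minimising the squared reconstruction error. For a column whose true variance is $\\sigma^2$, the correct negative log-likelihood would instead be\n",
    "\n",
    "$$ \\frac{1}{2\\sigma^2}\\left(y - \\hat{y}\\right)^2 + \\log \\sigma + \\frac{1}{2}\\log 2\\pi $$\n",
    "\n",
    "so fixing $\\sigma = 1$ on unscaled data would over-weight high-variance columns (such as `TAX`) relative to low-variance ones (such as `NOX`). Standardising each continuous column to unit variance makes the fixed-variance assumption reasonable."
   ]
  },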
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "___\n",
    "\n",
    "## Developer Guide\n",
    "\n",
    "The distributions are kept in `causalnex/structure/pytorch/dist_type/`. If you want to read through already implemented distributions, take a look there.\n",
    "\n",
    "### Single Parameter Distribution\n",
    "\n",
    "Contributing a new single-parameter distribution is straightforward. The steps are:\n",
    "\n",
    "- subclass `DistTypeBase` from `causalnex.structure.pytorch.dist_type._base`\n",
    "\n",
    "- implement the negative log-likelihood as the `loss` method\n",
    "\n",
    "- implement the inverse link function as the `inverse_link_function` method\n",
    "\n",
    "Each `DistType` class uses its `self.idx` attribute to select the data column it corresponds to.\n",
    "\n",
    "Multi-parameter distributions are covered in the next section.\n",
    "\n",
    "The Poisson dist type is shown below as an example:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from causalnex.structure.pytorch.dist_type._base import DistTypeBase\n",
    "\n",
    "\n",
    "class DistTypePoisson(DistTypeBase):\n",
    "    \"\"\" Class defining Poisson distribution type functionality \"\"\"\n",
    "\n",
    "    def loss(self, X: torch.Tensor, X_hat: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        https://pytorch.org/docs/master/generated/torch.nn.PoissonNLLLoss.html\n",
    "        Uses the functional implementation of the PoissonNLL class.\n",
    "        Returns the Poisson negative log-likelihood loss, averaged over rows.\n",
    "\n",
    "        Args:\n",
    "            X: The original data passed into NOTEARS (i.e. the reconstruction target).\n",
    "\n",
    "            X_hat: The reconstructed data.\n",
    "\n",
    "        Returns:\n",
    "            Scalar pytorch tensor of the reconstruction loss between X and X_hat.\n",
    "        \"\"\"\n",
    "        return nn.functional.poisson_nll_loss(\n",
    "            input=X_hat[:, self.idx],\n",
    "            target=X[:, self.idx],\n",
    "            reduction=\"mean\",\n",
    "            log_input=True,\n",
    "            full=False,\n",
    "        )\n",
    "\n",
    "    def inverse_link_function(self, X_hat: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Exponential inverse link function for Poisson data.\n",
    "\n",
    "        Args:\n",
    "            X_hat: Reconstructed data in the latent space.\n",
    "\n",
    "        Returns:\n",
    "            Modified X_hat.\n",
    "            MUST be same shape as passed in data.\n",
    "            Projects the self.idx column from the latent space to the dist_type space.\n",
    "        \"\"\"\n",
    "        X_hat[:, self.idx] = torch.exp(X_hat[:, self.idx])\n",
    "        return X_hat"
   ]
  },
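  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `loss` passes `log_input=True`, PyTorch treats the latent reconstruction as a log-rate, which matches the exponential inverse link. The standalone sketch below (random toy data, not part of the library) checks that `poisson_nll_loss` with `full=False` equals the Poisson negative log-likelihood written out by hand, $\\exp(\\hat{x}) - x \\hat{x}$ averaged over rows, with the constant $\\log(x!)$ term dropped. `X_hat_col` and `X_col` are illustrative names playing the roles of `X_hat[:, self.idx]` and `X[:, self.idx]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# latent reconstruction of one column; with log_input=True it is read as a log-rate\n",
    "X_hat_col = torch.randn(8)\n",
    "# observed counts for the same column\n",
    "X_col = torch.poisson(torch.full((8,), 3.0))\n",
    "\n",
    "loss = nn.functional.poisson_nll_loss(\n",
    "    input=X_hat_col, target=X_col, reduction='mean', log_input=True, full=False\n",
    ")\n",
    "# the same quantity written out by hand: mean of exp(x_hat) - x * x_hat\n",
    "manual = torch.mean(torch.exp(X_hat_col) - X_col * X_hat_col)\n",
    "print(float(loss), float(manual))"
   ]
  },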
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi Parameter Distributions\n",
    "\n",
    "These are a little trickier.\n",
    "\n",
    "Multi-parameter vector GLMs generally work by expanding the column space, with each column responsible for fitting a separate parameter. There are two common types of column expansion:\n",
    "\n",
    "- Column duplication (fit only). This is commonly used for multi-parameter distributions, for example fitting both the mean and the standard deviation of a Gaussian distribution. The expanded columns are _NOT_ used as features in prediction.\n",
    "\n",
    "- Column expansion (fit and predict). This is used when the expanded columns are also used as features in prediction, as with categorical distributions.\n",
    "\n",
    "The code sample below shows a column expansion example (a categorical dist type). The additional methods that need to be implemented are:\n",
    "\n",
    "- get_columns\n",
    "\n",
    "- preprocess_X\n",
    "\n",
    "- preprocess_tabu_edges\n",
    "\n",
    "- preprocess_tabu_nodes\n",
    "\n",
    "- modify_h\n",
    "\n",
    "- add_to_node\n",
    "\n",
    "- update_idx_col"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from causalnex.structure.pytorch.dist_type._base import DistTypeBase, ExpandColumnsMixin\n",
    "from causalnex.structure.structuremodel import StructureModel\n",
    "\n",
    "\n",
    "class DistTypeCategorical(ExpandColumnsMixin, DistTypeBase):\n",
    "    \"\"\" Class defining categorical distribution type functionality \"\"\"\n",
    "\n",
    "    # index group of categorical columns\n",
    "    idx_group = None\n",
    "    # column expander for later preprocessing\n",
    "    encoder = None\n",
    "\n",
    "    def get_columns(\n",
    "        self,\n",
    "        X: np.ndarray,\n",
    "    ) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Gets the column(s) associated with the instantiated DistType.\n",
    "\n",
    "        Args:\n",
    "            X: Full dataset to be selected from.\n",
    "\n",
    "        Returns:\n",
    "            1d or 2d np.ndarray of columns.\n",
    "        \"\"\"\n",
    "        return X[:, self.idx_group]\n",
    "\n",
    "    def preprocess_X(self, X: np.ndarray, fit_transform: bool = True) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Expands the feature dimension for each categorical column by:\n",
    "        - One hot encode each of the categorical features\n",
    "        - For each feature, get handle on groups of one-hot expanded columns\n",
    "        - Store the handle groups\n",
    "        - Return expanded array\n",
    "        NOTE: the number of expanded columns is EQUAL to the number of classes\n",
    "        for ease of use with the PyTorch loss functions.\n",
    "        This is technically wasteful computationally (only need C-1 columns).\n",
    "\n",
    "        Args:\n",
    "            X: The original passed-in data.\n",
    "\n",
    "            fit_transform: Whether the class first fits\n",
    "            then transforms the data, or just transforms.\n",
    "            Just transforming is used to preprocess new data after the\n",
    "            initial NOTEARS fit.\n",
    "\n",
    "        Returns:\n",
    "            Preprocessed X\n",
    "        \"\"\"\n",
    "        # deepcopy to prevent overwrite errors\n",
    "        X = deepcopy(X)\n",
    "\n",
    "        # fit the OneHotEncoder\n",
    "        if fit_transform:\n",
    "            self.encoder = OneHotEncoder(sparse=False, categories=\"auto\", drop=None)\n",
    "            self.encoder.fit(X[:, [self.idx]])\n",
    "\n",
    "        # expand columns for this feature\n",
    "        expanded_columns = self.encoder.transform(X[:, [self.idx]])\n",
    "        # update the original column with the first expanded column\n",
    "        X[:, self.idx] = expanded_columns[:, 0]\n",
    "        # append the remainder cols to X\n",
    "        X = self._expand_columns(X, expanded_columns[:, 1:])\n",
    "\n",
    "        # update the idx_group with expanded columns\n",
    "        if fit_transform:\n",
    "            self.idx_group = []\n",
    "            # preserve the first column location\n",
    "            self.idx_group.append(self.idx)\n",
    "            # the new cols are appended to the end of X contiguously\n",
    "            n_new_cols = expanded_columns.shape[1] - 1\n",
    "            idx_start = X.shape[1] - n_new_cols\n",
    "            # preserve location of expanded columns\n",
    "            self.idx_group += list(range(idx_start, X.shape[1]))\n",
    "\n",
    "        return X\n",
    "\n",
    "    def preprocess_tabu_edges(\n",
    "        self, tabu_edges: List[Tuple[int, int]]\n",
    "    ) -> List[Tuple[int, int]]:\n",
    "        \"\"\"\n",
    "        Update tabu_edges taking into account expanded columns.\n",
    "\n",
    "        Args:\n",
    "            tabu_edges: The original tabu_edges.\n",
    "\n",
    "        Returns:\n",
    "            Preprocessed tabu_edges.\n",
    "        \"\"\"\n",
    "        return self.update_tabu_edges(\n",
    "            idx_group=self.idx_group, tabu_edges=tabu_edges, tabu_idx_group=True\n",
    "        )\n",
    "\n",
    "    def preprocess_tabu_nodes(self, tabu_nodes: List[int]) -> List[int]:\n",
    "        \"\"\"\n",
    "        Update tabu_nodes taking into account expanded columns.\n",
    "\n",
    "        Args:\n",
    "            tabu_nodes: The original tabu_nodes.\n",
    "\n",
    "        Returns:\n",
    "            Preprocessed tabu_nodes.\n",
    "        \"\"\"\n",
    "        return self.update_tabu_nodes(idx_group=self.idx_group, tabu_nodes=tabu_nodes)\n",
    "\n",
    "    def modify_h(self, square_weight_mat: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Used to prevent spurious cycles between expanded columns and other features.\n",
    "        For example, A_cat1 -> B -> A_cat2 would not be penalized by the h(W) constraint.\n",
    "\n",
    "        This modification solves that by adding the expanded columns of the\n",
    "        squared adjacency matrix onto the original column. This effectively superimposes\n",
    "        all expanded column connections onto a single connection.\n",
    "\n",
    "        Args:\n",
    "            square_weight_mat: The squared adjacency matrix used in h(W).\n",
    "\n",
    "        Returns:\n",
    "            The modified W matrix.\n",
    "        \"\"\"\n",
    "        orig_idx = self.idx_group[0]\n",
    "        expand_idx = self.idx_group[1:]\n",
    "\n",
    "        # Add on the edges from expanded nodes.\n",
    "        square_weight_mat[orig_idx, :] = square_weight_mat[orig_idx, :] + torch.sum(\n",
    "            square_weight_mat[expand_idx, :], dim=0\n",
    "        )\n",
    "        # Add on the edges to expanded nodes.\n",
    "        square_weight_mat[:, orig_idx] = square_weight_mat[:, orig_idx] + torch.sum(\n",
    "            square_weight_mat[:, expand_idx], dim=1\n",
    "        )\n",
    "\n",
    "        return square_weight_mat\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_index(X_one_hot: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Recover the numerical columns by argmaxing a one-hot vector.\n",
    "\n",
    "        Args:\n",
    "            X_one_hot: The one-hot tensor to be collapsed.\n",
    "\n",
    "        Returns:\n",
    "            A 1d tensor representing the classes defined by the above one-hot\n",
    "            tensor.\n",
    "        \"\"\"\n",
    "        return torch.argmax(X_one_hot, dim=1)\n",
    "\n",
    "    def add_to_node(self, sm: StructureModel) -> StructureModel:\n",
    "        \"\"\"\n",
    "        Adds self to a node of a structure model corresponding to\n",
    "        all indexes in self.idx_group.\n",
    "\n",
    "        Args:\n",
    "            sm: The input StructureModel\n",
    "\n",
    "        Returns:\n",
    "            Updated StructureModel\n",
    "        \"\"\"\n",
    "        for idx in self.idx_group:\n",
    "            sm.nodes[idx][\"dist_type\"] = self\n",
    "        return sm\n",
    "\n",
    "    def loss(self, X: torch.Tensor, X_hat: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Uses the functional implementation of the CrossEntropyLoss class\n",
    "        https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss.\n",
    "\n",
    "        Returns the mean row wise cross entropy loss for a single group of categorical columns.\n",
    "\n",
    "        NOTE: the PyTorch loss takes class indices as its target,\n",
    "        so the one-hot columns are collapsed into a single class-index column.\n",
    "\n",
    "        Args:\n",
    "            X: The original data passed into NOTEARS (i.e. the reconstruction target).\n",
    "\n",
    "            X_hat: The reconstructed data.\n",
    "\n",
    "        Returns:\n",
    "            Scalar pytorch tensor of the reconstruction loss between X and X_hat.\n",
    "        \"\"\"\n",
    "\n",
    "        return nn.functional.cross_entropy(\n",
    "            input=X_hat[:, self.idx_group],\n",
    "            target=self._to_index(X[:, self.idx_group]),\n",
    "            reduction=\"mean\",\n",
    "        )\n",
    "\n",
    "    def inverse_link_function(self, X_hat: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Softmax inverse link function for categorical data.\n",
    "\n",
    "        Args:\n",
    "            X_hat: Reconstructed data in the latent space.\n",
    "\n",
    "        Returns:\n",
    "            Modified X_hat.\n",
    "            MUST be same shape as passed in data.\n",
    "            Projects the self.idx_group columns from the latent space to the dist_type space.\n",
    "        \"\"\"\n",
    "        X_hat[:, self.idx_group] = torch.softmax(X_hat[:, self.idx_group], dim=1)\n",
    "        return X_hat\n",
    "\n",
    "    @staticmethod\n",
    "    def make_node_name(colname: str, catidx: int) -> str:\n",
    "        \"\"\"\n",
    "        Renaming scheme for expanded categorical columns.\n",
    "        NOTE: the column is not renamed if catidx is 0,\n",
    "        because the original column name needs to stay constant.\n",
    "\n",
    "        Args:\n",
    "            colname: The base column used in the renaming.\n",
    "\n",
    "            catidx: The index of the categorical expansion.\n",
    "\n",
    "        Returns:\n",
    "            Updated column name.\n",
    "        \"\"\"\n",
    "        if catidx:\n",
    "            return f\"{colname}{catidx}\"\n",
    "        return colname\n",
    "\n",
    "    def update_idx_col(self, idx_col: Dict[int, str]) -> Dict[int, str]:\n",
    "        \"\"\"\n",
    "        Expand the named columns to include category names.\n",
    "\n",
    "        Args:\n",
    "            idx_col: The original index to column mapping.\n",
    "\n",
    "        Returns:\n",
    "            Updated index to column mapping.\n",
    "        \"\"\"\n",
    "        new_idx_cols = {}\n",
    "        colname = idx_col.pop(self.idx_group[0])\n",
    "        for catidx, idx in enumerate(self.idx_group):\n",
    "            new_idx_cols[idx] = self.make_node_name(colname, catidx)\n",
    "        return {**idx_col, **new_idx_cols}\n"
   ]
  },
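  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why `modify_h` is needed, here is a toy sketch (standalone, not library code). It mimics the index bookkeeping of `preprocess_X` on a two-class categorical column, using a hand-rolled one-hot encoding in place of `OneHotEncoder` and `np.hstack` in place of the `_expand_columns` helper (an assumption based on the comment that the helper appends the remaining columns to `X`). It then computes the NOTEARS acyclicity measure $h = \\mathrm{tr}\\left(e^{W \\circ W}\\right) - d$ on the squared weights, with and without the `modify_h` superposition, for the path A_cat1 -> B -> A_cat2 from the docstring.\n",
    "\n",
    "Without the superposition, `h_plain` is zero because no single column lies on a cycle. With it, the A-to-B and B-to-A entries form a cycle and `h_modified` comes out positive (about 0.16 here)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# toy data: column 0 is a categorical A with two classes, column 1 a continuous B\n",
    "X = np.array([[0.0, 1.5], [1.0, -0.3], [1.0, 0.8]])\n",
    "idx = 0\n",
    "\n",
    "# one-hot encode A (one column per class, as in preprocess_X)\n",
    "classes = np.unique(X[:, idx])\n",
    "one_hot = (X[:, [idx]] == classes).astype(float)\n",
    "\n",
    "# the first one-hot column overwrites A; the rest are appended at the end\n",
    "# (np.hstack stands in for the _expand_columns helper here)\n",
    "X_exp = X.copy()\n",
    "X_exp[:, idx] = one_hot[:, 0]\n",
    "X_exp = np.hstack([X_exp, one_hot[:, 1:]])\n",
    "n_new_cols = one_hot.shape[1] - 1\n",
    "idx_group = [idx] + list(range(X_exp.shape[1] - n_new_cols, X_exp.shape[1]))\n",
    "print(idx_group)  # [0, 2]\n",
    "\n",
    "\n",
    "def h(square_weight_mat):\n",
    "    # NOTEARS acyclicity measure on squared weights: tr(exp(W * W)) - d, zero iff acyclic\n",
    "    d = square_weight_mat.shape[0]\n",
    "    return float(torch.trace(torch.matrix_exp(square_weight_mat)) - d)\n",
    "\n",
    "\n",
    "def modify_h(square_weight_mat, idx_group):\n",
    "    # same superposition as DistTypeCategorical.modify_h above\n",
    "    orig_idx, expand_idx = idx_group[0], idx_group[1:]\n",
    "    square_weight_mat[orig_idx, :] = square_weight_mat[orig_idx, :] + torch.sum(\n",
    "        square_weight_mat[expand_idx, :], dim=0\n",
    "    )\n",
    "    square_weight_mat[:, orig_idx] = square_weight_mat[:, orig_idx] + torch.sum(\n",
    "        square_weight_mat[:, expand_idx], dim=1\n",
    "    )\n",
    "    return square_weight_mat\n",
    "\n",
    "\n",
    "# path A_cat1 -> B -> A_cat2: a cycle through A, but not between single columns\n",
    "W = torch.zeros(3, 3)\n",
    "W[0, 1] = 0.8  # A_cat1 -> B\n",
    "W[1, 2] = 0.5  # B -> A_cat2\n",
    "square = W * W\n",
    "\n",
    "h_plain = h(square.clone())\n",
    "h_modified = h(modify_h(square.clone(), idx_group))\n",
    "print(h_plain, h_modified)  # h_plain is ~0, h_modified is positive"
   ]
  },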
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "private-causalnex39",
   "language": "python",
   "name": "private-causalnex39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
